{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "wandb not available\n",
      "wandb not available\n"
     ]
    }
   ],
   "source": [
    "import ssms\n",
    "import lanfactory\n",
    "import os\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = \"angle\"\n",
    "RUN_SIMS = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the generator config (for MLP LANs)\n",
    "generator_config = deepcopy(ssms.config.data_generator_config[\"lan\"])\n",
    "# Specify generative model (one from the list of included models mentioned above)\n",
    "generator_config[\"dgp_list\"] = MODEL\n",
    "# Specify number of parameter sets to simulate\n",
    "generator_config[\"n_parameter_sets\"] = 256\n",
    "# Specify how many samples a simulation run should entail\n",
    "generator_config[\"n_samples\"] = 2000\n",
    "# Specify folder in which to save generated data\n",
    "generator_config[\"output_folder\"] = \"data/lan_mlp/\"\n",
    "\n",
    "# Make model config dict\n",
    "model_config = ssms.config.model_config[MODEL]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'name': 'angle',\n",
       " 'params': ['v', 'a', 'z', 't', 'theta'],\n",
       " 'param_bounds': [[-3.0, 0.3, 0.1, 0.001, -0.1], [3.0, 3.0, 0.9, 2.0, 1.3]],\n",
       " 'boundary': <function ssms.basic_simulators.boundary_functions.angle(t=1, theta=1)>,\n",
       " 'n_params': 5,\n",
       " 'default_params': [0.0, 1.0, 0.5, 0.001, 0.0],\n",
       " 'hddm_include': ['z', 'theta'],\n",
       " 'nchoices': 2}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator_config[\"output_folder\"] = (\n",
    "    \"data/lan_mlp/\"\n",
    "    + generator_config[\"dgp_list\"]\n",
    "    + \"/\"\n",
    "    + str(generator_config[\"n_samples\"])\n",
    "    + \"_\"\n",
    "    + str(generator_config[\"n_training_samples_by_parameter_set\"])\n",
    "    + \"/\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "if RUN_SIMS:\n",
    "    n_datafiles = 20\n",
    "    for i in range(n_datafiles):\n",
    "        print(\"Datafile: \", i)\n",
    "        my_dataset_generator = ssms.dataset_generators.lan_mlp.data_generator(\n",
    "            generator_config=generator_config, model_config=model_config\n",
    "        )\n",
    "        training_data = my_dataset_generator.generate_data_training_uniform(save=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "my_data = pickle.load(\n",
    "    open(\n",
    "        \"data/lan_mlp/angle/2000_1000/training_data_77bd5b0073ac11ee9b2f0242ac110002.pickle\",\n",
    "        \"rb\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([  3.1092792,   3.1623678,   3.1580684, ..., -66.77497  ,\n",
       "       -66.77497  , -66.77497  ], dtype=float32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "my_data[\"labels\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Network config: \n",
      "{'layer_sizes': [100, 100, 1], 'activations': ['tanh', 'tanh', 'linear'], 'train_output_type': 'logprob'}\n",
      "Train config: \n",
      "{'cpu_batch_size': 128, 'gpu_batch_size': 256, 'n_epochs': 5, 'optimizer': 'adam', 'learning_rate': 0.002, 'lr_scheduler': 'reduce_on_plateau', 'lr_scheduler_params': {}, 'weight_decay': 0.0, 'loss': 'huber', 'save_history': True}\n"
     ]
    }
   ],
   "source": [
    "network_config = lanfactory.config.network_configs.network_config_mlp\n",
    "\n",
    "print(\"Network config: \")\n",
    "print(network_config)\n",
    "\n",
    "train_config = lanfactory.config.network_configs.train_config_mlp\n",
    "\n",
    "print(\"Train config: \")\n",
    "print(train_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_config[\"cpu_batch_size\"] = 128\n",
    "train_config[\"gpu_batch_size\"] = 2048\n",
    "train_config[\"n_epochs\"] = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "folder_ = \"data/lan_mlp/\" + MODEL + \"/2000_1000/\"\n",
    "file_list_ = [folder_ + file_ for file_ in os.listdir(folder_)]\n",
    "\n",
    "# Training dataset\n",
    "jax_training_dataset = lanfactory.trainers.DatasetTorch(\n",
    "    file_ids=file_list_,\n",
    "    batch_size=train_config[\"gpu_batch_size\"]\n",
    "    if torch.cuda.is_available()\n",
    "    else train_config[\"cpu_batch_size\"],\n",
    "    label_lower_bound=np.log(1e-10),\n",
    "    features_key=\"data\",\n",
    "    label_key=\"labels\",\n",
    "    out_framework=\"jax\",\n",
    ")\n",
    "\n",
    "jax_training_dataloader = torch.utils.data.DataLoader(\n",
    "    jax_training_dataset, shuffle=True, batch_size=None, num_workers=1, pin_memory=True\n",
    ")\n",
    "\n",
    "# Validation dataset\n",
    "jax_validation_dataset = lanfactory.trainers.DatasetTorch(\n",
    "    file_ids=file_list_,\n",
    "    batch_size=train_config[\"gpu_batch_size\"]\n",
    "    if torch.cuda.is_available()\n",
    "    else train_config[\"cpu_batch_size\"],\n",
    "    label_lower_bound=np.log(1e-10),\n",
    "    features_key=\"data\",\n",
    "    label_key=\"labels\",\n",
    "    out_framework=\"jax\",\n",
    ")\n",
    "\n",
    "jax_validation_dataloader = torch.utils.data.DataLoader(\n",
    "    jax_validation_dataset,\n",
    "    shuffle=True,\n",
    "    batch_size=None,\n",
    "    num_workers=1,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-1.7034,  1.1065,  0.7065,  ...,  1.0142,  0.8140,  1.0000],\n",
      "        [-1.2475,  0.4886,  0.2193,  ..., -0.0594,  1.2449, -1.0000],\n",
      "        [-0.7896,  1.8203,  0.4408,  ...,  1.2277,  1.7719,  1.0000],\n",
      "        ...,\n",
      "        [-1.7004,  1.8325,  0.7201,  ..., -0.0194, -0.5118, -1.0000],\n",
      "        [ 1.1699,  1.0593,  0.1587,  ...,  0.7719,  0.5738, -1.0000],\n",
      "        [-2.7457,  2.1202,  0.7565,  ...,  0.1959,  2.1654, -1.0000]])\n",
      "tensor([[ -0.0369],\n",
      "        [  0.2909],\n",
      "        [ -0.2410],\n",
      "        ...,\n",
      "        [-23.0259],\n",
      "        [  0.9759],\n",
      "        [  0.1988]])\n"
     ]
    }
   ],
   "source": [
    "cnt = 0\n",
    "for xb, yb in jax_training_dataloader:\n",
    "    print(xb)\n",
    "    print(yb)\n",
    "    cnt += 1\n",
    "    if cnt > 0:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LOAD NETWORK\n",
    "jax_net = lanfactory.trainers.MLPJaxFactory(network_config=network_config, train=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "jax_trainer = lanfactory.trainers.ModelTrainerJaxMLP(\n",
    "    train_config=train_config,\n",
    "    model=jax_net,\n",
    "    train_dl=jax_training_dataloader,\n",
    "    valid_dl=jax_validation_dataloader,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found folder:  data\n",
      "Moving on...\n",
      "Found folder:  data/trained_model\n",
      "Moving on...\n",
      "Found folder:  data/trained_model/jax\n",
      "Moving on...\n",
      "Epoch: 0 of 20\n",
      "Training - Step: 0 of 2440 - Loss: 4.266737\n",
      "Training - Step: 1000 of 2440 - Loss: 0.20318128\n",
      "Training - Step: 2000 of 2440 - Loss: 0.09071794\n",
      "Epoch 0/20 time: 5.18427848815918s\n",
      "Validation - Step: 0 of 2440 - Loss: 0.13604295\n",
      "Validation - Step: 1000 of 2440 - Loss: 0.10391633\n",
      "Validation - Step: 2000 of 2440 - Loss: 0.109466106\n",
      "Epoch 0/20 time: 4.734684944152832s\n",
      "Epoch: 0 / 20, test_loss: 0.1084057167172432\n",
      "Epoch: 1 of 20\n",
      "Training - Step: 0 of 2440 - Loss: 0.09565063\n",
      "Training - Step: 1000 of 2440 - Loss: 0.07101747\n",
      "Training - Step: 2000 of 2440 - Loss: 0.07452312\n",
      "Epoch 1/20 time: 4.704767227172852s\n",
      "Validation - Step: 0 of 2440 - Loss: 0.06880751\n",
      "Validation - Step: 1000 of 2440 - Loss: 0.08170596\n",
      "Validation - Step: 2000 of 2440 - Loss: 0.074142\n",
      "Epoch 1/20 time: 4.5430285930633545s\n",
      "Epoch: 1 / 20, test_loss: 0.07345124334096909\n",
      "Epoch: 2 of 20\n",
      "Training - Step: 0 of 2440 - Loss: 0.046777636\n",
      "Training - Step: 1000 of 2440 - Loss: 0.05197097\n",
      "Training - Step: 2000 of 2440 - Loss: 0.07017851\n",
      "Epoch 2/20 time: 4.994797229766846s\n",
      "Validation - Step: 0 of 2440 - Loss: 0.07391778\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Exception in thread Thread-11 (_pin_memory_loop):\n",
      "Traceback (most recent call last):\n",
      "  File \"/usr/lib/python3.10/threading.py\", line 1016, in _bootstrap_inner\n",
      "    self.run()\n",
      "  File \"/usr/lib/python3.10/threading.py\", line 953, in run\n",
      "    self._target(*self._args, **self._kwargs)\n",
      "  File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/pin_memory.py\", line 51, in _pin_memory_loop\n",
      "    do_one_step()\n",
      "  File \"/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/pin_memory.py\", line 28, in do_one_step\n",
      "    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)\n",
      "  File \"/usr/lib/python3.10/multiprocessing/queues.py\", line 122, in get\n",
      "    return _ForkingPickler.loads(res)\n",
      "  File \"/usr/local/lib/python3.10/dist-packages/torch/multiprocessing/reductions.py\", line 307, in rebuild_storage_fd\n",
      "    fd = df.detach()\n",
      "  File \"/usr/lib/python3.10/multiprocessing/resource_sharer.py\", line 58, in detach\n",
      "    return reduction.recv_handle(conn)\n",
      "  File \"/usr/lib/python3.10/multiprocessing/reduction.py\", line 189, in recv_handle\n",
      "    return recvfds(s, 1)[0]\n",
      "  File \"/usr/lib/python3.10/multiprocessing/reduction.py\", line 157, in recvfds\n",
      "    msg, ancdata, flags, addr = sock.recvmsg(1, socket.CMSG_SPACE(bytes_size))\n",
      "ConnectionResetError: [Errno 104] Connection reset by peer\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m train_state \u001b[38;5;241m=\u001b[39m \u001b[43mjax_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_and_evaluate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m      2\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdata/trained_model/jax/\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m      3\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_file_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mMODEL\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m      4\u001b[0m \u001b[43m    \u001b[49m\u001b[43mrun_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtest_run_notebook\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m      5\u001b[0m \u001b[43m    \u001b[49m\u001b[43mwandb_on\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m      6\u001b[0m \u001b[43m    \u001b[49m\u001b[43mwandb_project_id\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtest_run_notebook\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m      7\u001b[0m \u001b[43m    \u001b[49m\u001b[43msave_data_details\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m      8\u001b[0m \u001b[43m    \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m      9\u001b[0m \u001b[43m    \u001b[49m\u001b[43msave_all\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m     10\u001b[0m \u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/lanfactory/trainers/jax_mlp.py:587\u001b[0m, in \u001b[0;36mModelTrainerJaxMLP.train_and_evaluate\u001b[0;34m(self, output_folder, output_file_id, run_id, wandb_on, wandb_project_id, save_history, save_model, save_config, save_all, save_data_details, verbose)\u001b[0m\n\u001b[1;32m    578\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch: \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(epoch) \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m of \u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_config[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_epochs\u001b[39m\u001b[38;5;124m\"\u001b[39m]))\n\u001b[1;32m    579\u001b[0m state, train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_epoch(\n\u001b[1;32m    580\u001b[0m     state,\n\u001b[1;32m    581\u001b[0m     train\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    584\u001b[0m     max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_config[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mn_epochs\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m    585\u001b[0m )\n\u001b[0;32m--> 587\u001b[0m state, test_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_epoch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    588\u001b[0m \u001b[43m    \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    589\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrain\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    590\u001b[0m \u001b[43m    \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    591\u001b[0m \u001b[43m    \u001b[49m\u001b[43mepoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    592\u001b[0m \u001b[43m    \u001b[49m\u001b[43mmax_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_config\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_epochs\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    593\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    595\u001b[0m \u001b[38;5;66;03m# Collect loss in training history\u001b[39;00m\n\u001b[1;32m    596\u001b[0m training_history\u001b[38;5;241m.\u001b[39mvalues[epoch, :] \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mint\u001b[39m(epoch), \u001b[38;5;28mfloat\u001b[39m(test_loss)]\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/lanfactory/trainers/jax_mlp.py:433\u001b[0m, in \u001b[0;36mModelTrainerJaxMLP.run_epoch\u001b[0;34m(self, state, train, verbose, epoch, max_epochs)\u001b[0m\n\u001b[1;32m    431\u001b[0m step \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m    432\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m X, y \u001b[38;5;129;01min\u001b[39;00m tmp_dataloader:\n\u001b[0;32m--> 433\u001b[0m     X_jax \u001b[38;5;241m=\u001b[39m \u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    434\u001b[0m     y_jax \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray(y)\n\u001b[1;32m    436\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m train:\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py:2089\u001b[0m, in \u001b[0;36marray\u001b[0;34m(object, dtype, copy, order, ndmin)\u001b[0m\n\u001b[1;32m   2085\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m array(np\u001b[38;5;241m.\u001b[39masarray(view), dtype, copy, ndmin\u001b[38;5;241m=\u001b[39mndmin)\n\u001b[1;32m   2087\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected input type for array: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mobject\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 2089\u001b[0m out_array: Array \u001b[38;5;241m=\u001b[39m \u001b[43mlax_internal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_convert_element_type\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2090\u001b[0m \u001b[43m    \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweak_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweak_type\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2091\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ndmin \u001b[38;5;241m>\u001b[39m ndim(out_array):\n\u001b[1;32m   2092\u001b[0m   out_array \u001b[38;5;241m=\u001b[39m lax\u001b[38;5;241m.\u001b[39mexpand_dims(out_array, \u001b[38;5;28mrange\u001b[39m(ndmin \u001b[38;5;241m-\u001b[39m ndim(out_array)))\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py:555\u001b[0m, in \u001b[0;36m_convert_element_type\u001b[0;34m(operand, new_dtype, weak_type)\u001b[0m\n\u001b[1;32m    553\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m type_cast(Array, operand)\n\u001b[1;32m    554\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 555\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mconvert_element_type_p\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43moperand\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnew_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnew_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    556\u001b[0m \u001b[43m                                     \u001b[49m\u001b[43mweak_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mweak_type\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:385\u001b[0m, in \u001b[0;36mPrimitive.bind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m    382\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams):\n\u001b[1;32m    383\u001b[0m   \u001b[38;5;28;01massert\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m config\u001b[38;5;241m.\u001b[39menable_checks\u001b[38;5;241m.\u001b[39mvalue \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m    384\u001b[0m           \u001b[38;5;28mall\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(arg, Tracer) \u001b[38;5;129;01mor\u001b[39;00m valid_jaxtype(arg) \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args)), args\n\u001b[0;32m--> 385\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfind_top_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:388\u001b[0m, in \u001b[0;36mPrimitive.bind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m    387\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind_with_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, trace, args, params):\n\u001b[0;32m--> 388\u001b[0m   out \u001b[38;5;241m=\u001b[39m \u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mmap\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfull_raise\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    389\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(full_lower, out) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmultiple_results \u001b[38;5;28;01melse\u001b[39;00m full_lower(out)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/core.py:868\u001b[0m, in \u001b[0;36mEvalTrace.process_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m    867\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprocess_primitive\u001b[39m(\u001b[38;5;28mself\u001b[39m, primitive, tracers, params):\n\u001b[0;32m--> 868\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py:140\u001b[0m, in \u001b[0;36mapply_primitive\u001b[0;34m(prim, *args, **params)\u001b[0m\n\u001b[1;32m    136\u001b[0m   msg \u001b[38;5;241m=\u001b[39m pjit\u001b[38;5;241m.\u001b[39m_device_assignment_mismatch_error(\n\u001b[1;32m    137\u001b[0m       prim\u001b[38;5;241m.\u001b[39mname, fails, args, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mjit\u001b[39m\u001b[38;5;124m'\u001b[39m, arg_names)\n\u001b[1;32m    138\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcompiled_fun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py:172\u001b[0m, in \u001b[0;36mxla_primitive_callable.<locals>.<lambda>\u001b[0;34m(*args, **kw)\u001b[0m\n\u001b[1;32m    170\u001b[0m   call \u001b[38;5;241m=\u001b[39m compiled\u001b[38;5;241m.\u001b[39munsafe_call\n\u001b[1;32m    171\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m prim\u001b[38;5;241m.\u001b[39mmultiple_results:\n\u001b[0;32m--> 172\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkw: \u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    173\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    174\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m call\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "train_state = jax_trainer.train_and_evaluate(\n",
    "    output_folder=\"data/trained_model/jax/\",\n",
    "    output_file_id=MODEL,\n",
    "    run_id=\"test_run_notebook\",\n",
    "    wandb_on=False,\n",
    "    wandb_project_id=\"test_run_notebook\",\n",
    "    save_data_details=True,\n",
    "    verbose=1,\n",
    "    save_all=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\n",
      "Requirement already satisfied: jax[cuda12_pip] in /usr/local/lib/python3.10/dist-packages (0.4.19)\n",
      "\u001b[33mWARNING: jax 0.4.19 does not provide the extra 'cuda12-pip'\u001b[0m\u001b[33m\n",
      "\u001b[0mRequirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax[cuda12_pip]) (0.3.1)\n",
      "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from jax[cuda12_pip]) (1.25.2)\n",
      "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax[cuda12_pip]) (3.3.0)\n",
      "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax[cuda12_pip]) (1.11.1)\n",
      "Collecting jaxlib==0.4.19+cuda12.cudnn89 (from jax[cuda12_pip])\n",
      "  Downloading https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.19%2Bcuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl (138.0 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m138.0/138.0 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hCollecting nvidia-cublas-cu12>=12.2.5.6 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cublas_cu12-12.3.2.9-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cuda-cupti-cu12>=12.2.142 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cuda_cupti_cu12-12.3.52-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
      "Collecting nvidia-cuda-nvcc-cu12>=12.2.140 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cuda_nvcc_cu12-12.3.52-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cuda-runtime-cu12>=12.2.140 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cuda_runtime_cu12-12.3.52-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cudnn-cu12>=8.9 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cudnn_cu12-8.9.4.25-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
      "Collecting nvidia-cufft-cu12>=11.0.8.103 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cufft_cu12-11.0.11.19-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cusolver-cu12>=11.5.2 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cusolver_cu12-11.5.3.52-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
      "Collecting nvidia-cusparse-cu12>=12.1.2.141 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_cusparse_cu12-12.1.3.153-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
      "Collecting nvidia-nccl-cu12>=2.18.3 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl.metadata (1.8 kB)\n",
      "Collecting nvidia-nvjitlink-cu12>=12.2 (from jax[cuda12_pip])\n",
      "  Downloading nvidia_nvjitlink_cu12-12.3.52-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
      "Collecting nvidia-cuda-nvrtc-cu12 (from nvidia-cudnn-cu12>=8.9->jax[cuda12_pip])\n",
      "  Downloading nvidia_cuda_nvrtc_cu12-12.3.52-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
      "Downloading nvidia_cublas_cu12-12.3.2.9-py3-none-manylinux1_x86_64.whl (417.9 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m417.9/417.9 MB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:02\u001b[0mm\n",
      "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.3.52-py3-none-manylinux1_x86_64.whl (14.0 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.0/14.0 MB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m:00:01\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cuda_nvcc_cu12-12.3.52-py3-none-manylinux1_x86_64.whl (22.0 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m22.0/22.0 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
      "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.3.52-py3-none-manylinux1_x86_64.whl (867 kB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m867.7/867.7 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cudnn_cu12-8.9.4.25-py3-none-manylinux1_x86_64.whl (720.1 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m720.1/720.1 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0mm0:03\u001b[0mm\n",
      "\u001b[?25hDownloading nvidia_cufft_cu12-11.0.11.19-py3-none-manylinux1_x86_64.whl (98.8 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.8/98.8 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
      "\u001b[?25hDownloading nvidia_cusolver_cu12-11.5.3.52-py3-none-manylinux1_x86_64.whl (125.2 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m125.2/125.2 MB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hDownloading nvidia_cusparse_cu12-12.1.3.153-py3-none-manylinux1_x86_64.whl (195.6 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m195.6/195.6 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
      "\u001b[?25hDownloading nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m166.0/166.0 MB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
      "\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.3.52-py3-none-manylinux1_x86_64.whl (20.5 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.5/20.5 MB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
      "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.3.52-py3-none-manylinux1_x86_64.whl (24.9 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.9/24.9 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-nvcc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, jaxlib, nvidia-cusolver-cu12\n",
      "  Attempting uninstall: jaxlib\n",
      "    Found existing installation: jaxlib 0.4.19\n",
      "    Uninstalling jaxlib-0.4.19:\n",
      "      Successfully uninstalled jaxlib-0.4.19\n",
      "Successfully installed jaxlib-0.4.19+cuda12.cudnn89 nvidia-cublas-cu12-12.3.2.9 nvidia-cuda-cupti-cu12-12.3.52 nvidia-cuda-nvcc-cu12-12.3.52 nvidia-cuda-nvrtc-cu12-12.3.52 nvidia-cuda-runtime-cu12-12.3.52 nvidia-cudnn-cu12-8.9.4.25 nvidia-cufft-cu12-11.0.11.19 nvidia-cusolver-cu12-11.5.3.52 nvidia-cusparse-cu12-12.1.3.153 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.3.52\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install -U \"jax[cuda12_pip]\" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loaded Net\n",
    "jax_infer = lanfactory.trainers.MLPJaxFactory(\n",
    "    network_config=\"data/torch_models/angle/angle_torch_network_config.pickle\",\n",
    "    train=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "forward_pass, forward_pass_jitted = jax_infer.make_forward_partial(\n",
    "    seed=42,\n",
    "    input_dim=model_config[\"n_params\"] + 2,\n",
    "    state=\"data/trained_model/jax/test_run_notebook_lan_angle__train_state.jax\",\n",
    "    add_jitted=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "# Test parameters:\n",
    "v, a, z, t, theta = 0.5, 1.5, 0.5, 0.3, 0.3\n",
    "\n",
    "# Comparison simulator run\n",
    "sim_out = ssms.basic_simulators.simulator.simulator(\n",
    "    model=MODEL, theta=[v, a, z, t, theta], n_samples=50000\n",
    ")\n",
    "\n",
    "# Make input matric\n",
    "input_mat = jnp.zeros((2000, 7))\n",
    "input_mat = input_mat.at[:, 0].set(jnp.ones(2000) * v)\n",
    "input_mat = input_mat.at[:, 1].set(jnp.ones(2000) * a)\n",
    "input_mat = input_mat.at[:, 2].set(jnp.ones(2000) * z)\n",
    "input_mat = input_mat.at[:, 3].set(jnp.ones(2000) * t)\n",
    "input_mat = input_mat.at[:, 4].set(jnp.ones(2000) * theta)\n",
    "input_mat = input_mat.at[:, 5].set(\n",
    "    jnp.array(\n",
    "        np.concatenate(\n",
    "            [\n",
    "                np.linspace(5, 0, 1000).astype(np.float32),\n",
    "                np.linspace(0, 5, 1000).astype(np.float32),\n",
    "            ]\n",
    "        )\n",
    "    )\n",
    ")\n",
    "input_mat = input_mat.at[:, 6].set(\n",
    "    jnp.array(\n",
    "        np.concatenate([np.repeat(-1.0, 1000), np.repeat(1.0, 1000)]).astype(np.float32)\n",
    "    )\n",
    ")\n",
    "\n",
    "net_out = forward_pass_jitted(input_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array([[-10.9201765],\n",
       "       [-10.892712 ],\n",
       "       [-10.864927 ],\n",
       "       ...,\n",
       "       [-10.891985 ],\n",
       "       [-10.910214 ],\n",
       "       [-10.928265 ]], dtype=float32)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([4.69769437e-04, 8.22096514e-04, 1.64419303e-03, 5.05002144e-03,\n",
       "        1.11570241e-02, 1.56198338e-02, 2.24314906e-02, 3.03001287e-02,\n",
       "        4.07524986e-02, 5.19095227e-02, 6.55328364e-02, 8.18573243e-02,\n",
       "        8.94910777e-02, 1.05345796e-01, 1.12509780e-01, 1.30478461e-01,\n",
       "        1.36233137e-01, 1.41165716e-01, 1.48916911e-01, 1.36468021e-01,\n",
       "        9.31317908e-02, 1.98477587e-02, 0.00000000e+00, 0.00000000e+00,\n",
       "        0.00000000e+00, 0.00000000e+00, 7.04654155e-04, 1.05110911e-01,\n",
       "        3.65832949e-01, 5.13927764e-01, 5.39999967e-01, 5.00891662e-01,\n",
       "        4.50978659e-01, 3.92140037e-01, 3.35297935e-01, 2.85502375e-01,\n",
       "        2.31478890e-01, 1.86968236e-01, 1.51265759e-01, 1.18734225e-01,\n",
       "        9.13701554e-02, 6.34188739e-02, 4.07524986e-02, 2.65419732e-02,\n",
       "        1.70291421e-02, 7.98608042e-03, 3.40582841e-03, 1.29186595e-03,\n",
       "        2.34884718e-04, 1.17442359e-04]),\n",
       " array([-4.17047739, -4.00018108, -3.82988478, -3.65958847, -3.48929216,\n",
       "        -3.31899586, -3.14869955, -2.97840324, -2.80810694, -2.63781063,\n",
       "        -2.46751432, -2.29721802, -2.12692171, -1.9566254 , -1.7863291 ,\n",
       "        -1.61603279, -1.44573648, -1.27544018, -1.10514387, -0.93484756,\n",
       "        -0.76455126, -0.59425495, -0.42395864, -0.25366234, -0.08336603,\n",
       "         0.08693027,  0.25722658,  0.42752289,  0.59781919,  0.7681155 ,\n",
       "         0.93841181,  1.10870811,  1.27900442,  1.44930073,  1.61959703,\n",
       "         1.78989334,  1.96018965,  2.13048595,  2.30078226,  2.47107857,\n",
       "         2.64137487,  2.81167118,  2.98196749,  3.15226379,  3.3225601 ,\n",
       "         3.49285641,  3.66315271,  3.83344902,  4.00374533,  4.17404163,\n",
       "         4.34433794]),\n",
       " [<matplotlib.patches.Polygon at 0x7fc07dad4820>])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAABL/ElEQVR4nO3deXiU5b3/8ffMZDLJZF9IQkLYg2GPLAlo3VoqLl3UU0s3pVTpBq2e/M5ppa3Q0422WqXHUlGrtadqpbUqbW1pMRW1FWUT2UHWbGQn+56Z3x+TmSSQQBIm88zyeV3XXCQzz8x8EzLJZ+77e9+Pyel0OhERERExiNnoAkRERCS0KYyIiIiIoRRGRERExFAKIyIiImIohRERERExlMKIiIiIGEphRERERAylMCIiIiKGCjO6gMFwOByUlpYSExODyWQyuhwREREZBKfTSUNDA+np6ZjNA49/BEQYKS0tJTMz0+gyREREZBiKiooYM2bMgLcHRBiJiYkBXF9MbGyswdWIiIjIYNTX15OZmen5Oz6QgAgj7qmZ2NhYhREREZEAc7EWCzWwioiIiKEURkRERMRQCiMiIiJiKIURERERMZTCiIiIiBhKYUREREQMpTAiIiIihlIYEREREUMpjIiIiIihFEZERETEUAojIiIiYiiFERERETGUwoiIiIgYKiDO2isi4i8Ol9XzSMExIqwW/t/1U0iPjzS6JJGApzAiIjJIFfWtfPrxtznb3AHA7sKz/PXrVxEZbjG4MpHApmkaEZFB+t9/vs/Z5g7GJ9kZFWPjZFUTT/7rhNFliQQ8hRERkUGob+3gDzuLAVh72yy+fdNUAJ5+6xQdXQ4jSxMJeJqmEREZhM37ykjqrGBWUhcLIgvpssEL0SWcbWpnx1tmrpiUBPYkiM80ulSRgKMwIiIyCDv37uVV239jb2qDx12/PJ8BsAEF3RerHVZsVyARGSKFERGRi+jocnCysBC7qY2ia39O5pQcAA6cqecbL+wl2hbGc7fEY3n5S9BcrTAiMkQKIyIiF/FeUS3N7V1gg4ys2ZCeA8BlqQ5KXmmltrmDgx3RzDS2TJGApQZWEZGL2Ha82vOx2WTyfBxmMXNV1igA9hbX+bwukWChMCIichHvFdcOeFvuhEQADpQqjIgMl8KIiMgFOJ1O3rvAqEfueFcYOVLW4KuSRIKOwoiIyAWU17dR2dCGxWzq9/aslGji7VZaO7XXiMhwKYyIiFyAe4pmbEL/56Axm03MG5fow4pEgo/CiIjIBeztDiNZqTEDHjNnXLxvihEJUgojIiIXcPiMqxdkckr0gMfMyoj3UTUiwUlhRETkAo5WuMJIZqJ9wGNmZMR6Pm5o7RzxmkSCjcKIiMgAmts7KappAWBc0sBhJN4ezui4CACOVzb6pDaRYKIwIiIygGMVrmCRFBVOXIT1gse6p3Her1AYERkqhRERkQG8X+4KFlmpA/eLuLnDyLEK7TciMlQKIyIiA3D3i0y5wEoat8mjXGHkeGXTiNYkEowURkREBtAzMnLxMDJhVBQAZ+paaWxTE6vIUOisvSIibrVF0NxzUjxT2R6mm9qYbYmGqvIL3rV3T8mRsgbmjksYsTJFgo3CiIgIuILI+lzoaPZc9SSADXil+wqrHexJF30ohRGRoVEYEREB14hIRzPc9gQkT6GwppmvPLubqHALG7+0ABMmVxCJz7zoQx0uq/dBwSLBQ2FERKS35CmQnsOR2nIOOM8yIzkWU/rlQ3oI966tIjI4amAVEenH6WrXqphxSVFDvu+hsnqcTqe3SxIJWgojIiL9ONUdRsZfYOfV/oSZTTS0dnKmrnUkyhIJSsMKI+vXr2f8+PFERESQl5fH9u3bBzz26aefxmQy9blEREQMu2AREV84Xe1qZB3qyMiYhEhAfSMiQzHkMLJx40by8/NZs2YNu3fvZvbs2SxevJiKiooB7xMbG8uZM2c8l9OnT19S0SIiI809MjIheWhhxB1eDqlvRGTQhhxGHnroIZYvX86yZcuYNm0aGzZswG6389RTTw14H5PJRFpamueSmpp6SUWLiIyk9k4HJWcvfoK8/lxuL2e66SSNp3ZB6Z6eS22R1+sUCRZDWk3T3t7Orl27WLVqlec6s9nMokWL2LZt24D3a2xsZNy4cTgcDubMmcOPfvQjpk+fPuDxbW1ttLW1eT6vr9dwp4j4TtHZZhxOsIdbGBVtG9yd7ElgtfPhQ/fzYRtwGni81+1WO6zYPqilwSKhZkgjI1VVVXR1dZ03spGamkpZWVm/97nssst46qmn2LRpE8888wwOh4MrrriC4uLiAZ9n7dq1xMXFeS6ZmXrxiojv9F5JYzKZBnen+ExYsZ3Kz/yDm9t+yMc7fkT7Xa/BF1937V3S0dxnd1cR6THi+4wsXLiQhQsXej6/4oormDp1Ko899hjf//73+73PqlWryM/P93xeX1+vQCIiPnOqytW8OtSVNMRnkhw3hkJbLQ2tnRwPm8zU0bEjUKFIcBnSyEhycjIWi4Xy8r7naCgvLyctLW1Qj2G1Wrn88ss5duzYgMfYbDZiY2P7XEREfKWwxhVGxg41jODqkbus+8R6R8vVxCoyGEMKI+Hh4cydO5eCggLPdQ6Hg4KCgj6jHxfS1dXFvn37GD169NAqFRHxkeLu5tUxCUMPIwBT0lxh5EiZwojIYAx5miY/P5+lS5cyb948cnNzWbduHU1NTSxbtgyAO++8k4yMDNauXQvA9773PRYsWMDkyZOpra3lgQce4PTp09x9993e/UpERLyk+KxrZMS9Z8hQuUdGFEZEBmfIYWTJkiVUVlayevVqysrKyMnJYfPmzZ6m1sLCQszmngGXs2fPsnz5csrKykhISGDu3Lm89dZbTJs2zXtfhYiIF7mX9WYON4y4R0Y0TSMyKMNqYF25ciUrV67s97atW7f2+fzhhx/m4YcfHs7TiIj4XENbJw1tnQCkx1/ayEjx2RYa2zqJ9lp1IsFJ56YREemlosF1TpmkqHDs4cNbcJgQFU5KjGt/EjWxilycwoiISC8V9a4NFzOGOUXj5p6qOaq+EZGLUhgREenFHUaG27zqNqV7quawwojIRSmMiIj0UtHQPTIyzH4RN8/IiKZpRC5KYUREpJfyhkvbY8RNG5+JDJ7CiIhIL96apslKjcZkgqrGds62dHijNJGgpTAiItKLZ5rmEsOIPTyMsYmu0ZXCqqZLrkskmCmMiIj00tDq2mPkUntGoGeq5lR18yU/lkgwUxgRETlHbEQYMRHWS34cdxPr6RqNjIhciMKIiMg5Rsdd+qgI9CzvPa2REZELUhgRETlHalyEVx4n2z0yUq2REZELURgRETnH6FjvhJHxyVFYLSZaOhxeeTyRYKUwIiJyDm+NjFgtZiaN0mnyRC5GYURE5BxpXhoZgZ4mVhEZmMKIiMg5RntpZAR6mlhFZGAKIyIi50j14shItkZGRC5KYUREBGjv6mkyHamRkQ6H02uPKxJMFEZERICapnbA1XQab7/0Dc/cxiREEml1/ao9U9vitccVCSYKIyIiuE5oB5AcHY7JZPLa45pMJsYlRQFwSvuNiPRLYUREBKhucp0gLyk63OuPPT7JdcI87cQq0j+FERERoKrRFUaSo21ef2yNjIhcmMKIiAhQ3dgBQFKU90dGxiVqZETkQhRGRESA6sZWAJKiRmBkJNk1MlJW30pze6fXH18k0IUZXYCIiD842+waGUkcgZGR+EjX6pxJlFB8YFv/G6HZkyA+0+vPLRIIFEZERICz3Ut7vbms18OeRKvJxs/Dfwmbftn/MVY7rNiuQCIhSWFERASobekAE8SPwMgI8Zk8OuN3vLrzELfmZHD3VRP63l51FF5cDs3VCiMSkhRGRCTktXZ00dzeBTZIGImRESB9bBYHdrSS2JjM3ek5I/IcIoFKDawiEvIqG9o8H0eFW0bkOdx9IkfKGkbk8UUCmcKIiIS8ysaeMGLCe7uv9pbVHUYqGto8/Ski4qIwIiIhr6rXyMhIibaFkZkYCcCRco2OiPSmMCIiIc99XpqRdln36MhRhRGRPhRGRCTkVfpgZATgsjRXGDmsvhGRPhRGRCTkVTX6Joy4m1iPKoyI9KEwIiIhz1dhJDstFnCNjDgcTp88p0ggUBgRkZDnq2maSaOisIWZaWzrpLBGJ80TcVMYEZGQ56uRkTCLmezuvpH9pXU+eU6RQKAwIiIhz1eraQCmpccBcKC03mfPKeLvFEZEJKS1dnTR2Nbps+ebkeHqG1EYEemhMCIiIa2mezfUMPPI7Lx6runukZGSOpxONbGKgMKIiIQ4dxiJjRyZE+SdKzstBovZRHVTO+X1vulVEfF3CiMiEtJqmzsAiI3wzUnMI6wWJo2KAuCAmlhFAIUREQlxNc2ukZEYH4URgBlqYhXpQ2FEREJabXcYiY3wzTQNwLR0dxOrRkZEQGFEREKcr3tGoKeJdX+JRkZEQGFEREKcu2fEl9M07pGRktoWz8iMSChTGBGRkOYZGfHhNE1cpJWxiXYADqpvRERhRERC21kDekYApnePjmhbeBGFEREJcZ4w4sOeEegJI1pRI6IwIiIh7myT73tGoNdOrAojIgojIhLa3CMjcZE+DiPd56g5UdlIS2eXT59bxN8ojIhIyGrt6KK53RUEon3cM5ISE0FqrA2HE45XNPn0uUX8jcKIiIQs97Jei9lEtM3i8+efNSYegPfLG3z+3CL+RGFEREKWe1lvgt2KCd+ctbe32WNcfSNHKxp9/twi/mRYYWT9+vWMHz+eiIgI8vLy2L59+6Du9/zzz2MymbjllluG87QiIl7l3nAswR5uyPNrZETEZchhZOPGjeTn57NmzRp2797N7NmzWbx4MRUVFRe836lTp/iv//ovrrrqqmEXKyLiTTWGhxHXyMiZulZDnl/EXww5jDz00EMsX76cZcuWMW3aNDZs2IDdbuepp54a8D5dXV189rOf5X/+53+YOHHiJRUsIuItZ7t7RuLtvm1edYu3hzMuyW7Ic4v4kyGFkfb2dnbt2sWiRYt6HsBsZtGiRWzbtm3A+33ve98jJSWFu+66a1DP09bWRn19fZ+LiIi31be4wkicjzc86809VSMSyoYURqqqqujq6iI1NbXP9ampqZSVlfV7n3/96188+eSTPPHEE4N+nrVr1xIXF+e5ZGZmDqVMEZFBqW91hRFf777am7uJVSSUjehqmoaGBu644w6eeOIJkpOTB32/VatWUVdX57kUFRWNYJUiEqrqWzoB35+XpjeNjIjAkLYcTE5OxmKxUF5e3uf68vJy0tLSzjv++PHjnDp1io9+9KOe6xwOh+uJw8I4cuQIkyZNOu9+NpsNm802lNJERIbMPTLi691Xe5uREYu5e1VxdVM7SYZVImKcIY2MhIeHM3fuXAoKCjzXORwOCgoKWLhw4XnHZ2dns2/fPvbs2eO5fOxjH+O6665jz549mn4REUO5e0aMnKaxh4cxNtHVxHpUS3wlRA357UB+fj5Lly5l3rx55Obmsm7dOpqamli2bBkAd955JxkZGaxdu5aIiAhmzJjR5/7x8fEA510vIuJr9a3GT9MAZKVEQyPUnNoPEwcYG7EnQbzewElwGnIYWbJkCZWVlaxevZqysjJycnLYvHmzp6m1sLAQs1kbu4qI/2vwg5ERgMzMsTQft3HzsdVwbICDrHZYsV2BRILSsCZKV65cycqVK/u9bevWrRe879NPPz2cpxQR8bo6TxgJA4dxdWRlZbPo7w8wNrKF3y3PO39r+qqj8OJyaK5WGJGgZFzXloiIgZxOZ8/S3ggrNBtXS3ZaLFWWFEpbHBTapjAuKcq4YkQMoPkUEQlJrR0OOrqcgPHTNOFhZqalxwLwbmGtobWIGEFhRERCkntUxGyCqHCLwdXA5WPjAXi38KyxhYgYQGFEREJS72W9JpPpIkePvDljEwDYrZERCUEKIyISkvr0i/gB98jIoTP1tLR3GVuMiI8pjIhISPJsBW/g7qu9ZcRHkhJjo9PhZF9JndHliPiUwoiIhCR/GxkxmUy9pmrUNyKhRWFEREKSu2ckzuCVNL2piVVClcKIiIQkz4ZnfjIyAjBnXE8Tq9PpNLgaEd9RGBGRkOQ5L42f9IwAzMyII8xsorKhjeKzLUaXI+IzCiMiEpLq/XBkJMJq6dn8rKjW2GJEfEhhRERCkqeB1Y96RqDXfiOn1TcioUNhRERCkr8t7XXzNLFqZERCiMKIiISkhu6RkWibf46MHCyto7VDm59JaPCvtwQiIiOttgiaqxndfIROUwupjdFQegaqjhpdGQBjEiJJjrZR1djG/pI65o1PNLokkRGnMCIioaO2CNbnQkczGwBswF973W61gz3JmNq6mUwmLh8bz5aD5bxbWKswIiFBYUREQkdzNXQ0w21PcPsfq2lu7+LxO+aSER/put2eBPGZxtaIa6pmy8Fy7cQqIUNhRERCjiMpi53tUTidYM28HGIijC6pjzndTay7Tp/F6XRi/DmFRUaWGlhFJOS0djpwb3Aa42cNrACzM+OxWkxUNLRRVKPNzyT4KYyISMhpaXetUjGbIMLqf78GI6wWZmbEAbDjVI3B1YiMPP97FYqIjLDm7jASbQvDZPLPSZD53Y2rO08rjEjwUxgRkZDT0u7a8Cza5r9tc+5VNDtOqYlVgp/CiIiEHPc0TXSE/4aRud1n8D1W0Uhd9wZtIsFKYUREQk5T986mUX48MpIYFc7klGgADp1pMLgakZGlMCIiIae1V8+IP5s/vmdreJFgpjAiIiGnuSMwwsi8ca6+kYOl9QZXIjKyFEZEJOQ0dzew+vM0DfSsqHm/otHgSkRGlsKIiISclgCZpslMjCQlxkanw2l0KSIjSmFEREJOoIQRk8nkGR0RCWYKIyIScpoDYGmv27zuJlaRYKYwIiIhp7kjMHpGgD4jI11OTddIcFIYEZGQ09ruACAmAMJIdloMkd3nzymsbja4GpGRoTAiIiHHPU0TCCMjYRYz2aNjATh4Rkt8JTgpjIhIyAmUfUbcpo12ncH3gPYbkSAVGK9EEREvCoQT5fU2fXQM7IGm4gM4S9/FxDlnGrYnQXymIbWJeENgvBJFRLwoEE6U19uUieNpdtr4n66fw+M/P/8Aqx1WbFcgkYAVGK9EEREvavGcKM9icCWDEzlqPF9MfoyS0hLuWZTF9VNTe26sOgovLofmaoURCVgKIyISctwbmgbKNA3ApKyp/KMknL9Xp3F9+myjyxHxKjWwikhIMpsg0hoYIyMAeRNc+428c7La4EpEvE9hRERCkj08DJPJdPED/cS88YlYzCaKz7ZQUttidDkiXqUwIiIhKTI8cEZFwDWlNCPdtd/Ido2OSJBRGBGRkGQPsDACkDcxCYB3TtQYXImIdymMiEhICqR+EbeevhGFEQkuCiMiEpICcWRk3vhETCY4WdVERX2r0eWIeI3CiIiEJHt44CzrdYuLtDI1zdU38rZGRySIKIyISEgKtAZWt7yJ3VM1J9TEKsFDYUREQlIgTtMA5E3obmLVyIgEEYUREQlJgRpGcrubWI9VNFLV2GZwNSLeoTAiIiEp0hp4PSMAiVHhXJYaA8AOjY5IkFAYEZGQFKgjI9Crb0RhRIKEwoiIhKRAbWCFnqmat9XEKkFCYUREQlIgbnrm5g4jR8obaGjtMLgakUunMCIiISmQp2lSYiKYOCoKpxP2l9YbXY7IJVMYEZGQFMjTNNCzxHd/icKIBL5hhZH169czfvx4IiIiyMvLY/v27QMe++KLLzJv3jzi4+OJiooiJyeH3/72t8MuWETEGwJxB9beFnQ3sR4orTO4EpFLN+QwsnHjRvLz81mzZg27d+9m9uzZLF68mIqKin6PT0xM5Nvf/jbbtm1j7969LFu2jGXLlvH3v//9kosXERmuQJ6mgZ6RkeOVjQZXInLphhxGHnroIZYvX86yZcuYNm0aGzZswG6389RTT/V7/LXXXsutt97K1KlTmTRpEvfccw+zZs3iX//61yUXLyIyXIE+TZMWF8G4JDsOp9GViFy6IYWR9vZ2du3axaJFi3oewGxm0aJFbNu27aL3dzqdFBQUcOTIEa6++uoBj2tra6O+vr7PRUTEmwJ9ZARgQffoiEigG1IYqaqqoquri9TU1D7Xp6amUlZWNuD96urqiI6OJjw8nJtvvplHHnmED3/4wwMev3btWuLi4jyXzMzMoZQpInJR9gDdgbW3hZMURiQ4+GQ1TUxMDHv27GHHjh388Ic/JD8/n61btw54/KpVq6irq/NcioqKfFGmiAQ5Jz1zGoE+TQOwYGJPGGlo6zSwEpFLM6S3BsnJyVgsFsrLy/tcX15eTlpa2oD3M5vNTJ48GYCcnBwOHTrE2rVrufbaa/s93mazYbPZhlKaiMhFtXU6iej+OBimadLiIsiIj4QWOFBaz4IJRlckMjxDGhkJDw9n7ty5FBQUeK5zOBwUFBSwcOHCQT+Ow+GgrU1nmxQR32rt6PJ8HBHAO7D2NjMjDoB9xVriK4FryJOm+fn5LF26lHnz5pGbm8u6detoampi2bJlANx5551kZGSwdu1awNX/MW/ePCZNmkRbWxt//etf+e1vf8ujjz7q3a9EROQi3GEk3GLGYjYZXI13zMqMg2Owt7jW6FJEhm3IYWTJkiVUVlayevVqysrKyMnJYfPmzZ6m1sLCQszmngGXpqYmvvrVr1JcXExkZCTZ2dk888wzLFmyxHtfhYjIILR1OgCIsAbP5tPukZGT1U3UNrcTbw83uCKRoRtWO/nKlStZuXJlv7ed25j6gx/8gB/84AfDeRoREa9q6x4ZCZYpGoDE7vDhdMLbJ2q4YcbA/Xsi/ip43h6IiFxEa4drZMQWRGGkt7dPVBtdgsiwBP5CexGRQWrtdI+MBN/7sMmmEqqOvgOlHeffaE+CeO3XJP5LYUREQoa7gTUiLIjCiD0JZ5idn/NLaAQe7+cYqx1WbFcgEb+lMCIiIaPV3cAaFkTTNPGZmFZuZ8WvtnCquplVN2bzgcnJPbdXHYUXl0NztcKI+C2FEREJGW0d7tU0QRRGAOIzSc7K5ZWq0/zjbBofSJ9hdEUiQxJEY5UiIhfmnqaxBWHPiPs8NduOq4lVAk/wvSJFRAYQjEt73fImJGEywfsVjVQ2aIdrCSwKIyISMnqW9gbfr76EqHCy02IBLfGVwBN8r0gRkQG4l/bagqmBtZeF3Wfx3aYwIgFGYUREQkZPA2tw/upbMDER0MiIBJ7gfEWKiPSjpbMTCLKlvb24+0ZOVDZRXt9qdDkig6YwIiIhw90zEqxhJM5uZXq6+kYk8CiMiEjIaA/Cs/aey9M3oiW+EkCC9xUpInKOnn1GgnNkBHr2G9HIiAQShRERCRlBuwNrL/PGJ2I2wanqZs7UtRhdjsigKIyISMjwnCgviKdpYiOszMyIAzRVI4EjeF+RIiLn8OwzEsQjIwALtDW8BBiFEREJGT2raYL7V5+7ifXtkwojEhiC+xUpItKto8tBp8MJBO/SXrd54xOxmE0U1bRQ3qD9RsT/KYyISEhobu/yfByM56bpLdoWxqwxrr6RvcV1BlcjcnHB/YoUEenW0iuMWC0mAyvxDfdUjcKIBAKFEREJCc3tnZ6PTYRAGOluYt1fojAi/k9hRERCQu9pmlAwd1wCVouJioY2o0sRuSiFEREJCS0doRVG7OFhzB4Tb3QZIoOiMCIiIaElxEZGoGeqRsTfKYyISEgItWka6GliBXDiNLASkQtTGBGRkNDS0Xnxg4LMnHEJWM2uX/MltdpvRPyXwoiIhIRQHBmJsFrIHh0NwN7iWmOLEbkAhRERCQmh2DMCMDszAYC9RVriK/5LYUREQkIojowAzO7eifW94locDvWNiH9SGBGRkBCqYSQrNQaA+tZODp6pN7gakf4pjIhISGhpD70GVgCruWe32beOVxlYicjAFEZEJCSE6shIb/8+Vm10CSL9UhgRkZDQHGI7sPZn+8ka2jsdRpchch6FEREJCaG6msYtPtJKS0cXe4pqjS5F5DwKIyISEppDtGfEbWb3qpp/H1PfiPifMKMLEBHxhVAfGbk6oYY3TVWcOVwN05v73mhPgvhMYwoTQWFEREJES0cXVqOLMII9Cax2rnhvFa/YgGrg8XOOsdphxXYFEjGMwoiIhITm9i7ijC7CCPGZrqDRXM0XfrOD8vo2vvux6cwf59qZlaqj8OJyaK5WGBHDKIyISEhoCdUwAq6QEZ/JqCwz/9xZxN+rU5m/cJrRVYl4qIFVREKC9hmBK7OSAfj3ce03Iv5FYUREgp7D4aRF+4xwxaQkAA6dqae6sc3gakR6KIyISNBr7VQQAUiOtpGd5jpXzbYTGh0R/6EwIiJBT1M0Pa6Y1D1Vo63hxY8ojIhI0HPvMWIL06+8Kye7pmp00jzxJ3plikjQc4+MRFgtBldivNwJiVjMJk5XN1N8tvnidxDxAYUREQl67q3gNTICMRFWZndvDf+WpmrET+iVKSJBr0UjI31cOdm9xFdTNeIfFEZEJOj1TNPoVx70bmKtwuF0GlyNiMKIiIQA9x4jmqZxmTMuHnu4harGdk5WNxldjojCiIgEv55pGp0BA8AWZmHhRNeqmt2na40tRgSFEREJAe4GVk3T9Lh6yigAdp8+a3AlIgojIhICmrunaSI0TePhDiMHz9QbXInIMMPI+vXrGT9+PBEREeTl5bF9+/YBj33iiSe46qqrSEhIICEhgUWLFl3weBERb/NseqbVNB7jk+xkJkbS6VADqxhvyGFk48aN5Ofns2bNGnbv3s3s2bNZvHgxFRUV/R6/detWPv3pT/Paa6+xbds2MjMzuf766ykpKbnk4kVEBqPZE0Y0MuJmMpm4OmuU0WWIAMMIIw899BDLly9n2bJlTJs2jQ0bNmC323nqqaf6Pf7ZZ5/lq1/9Kjk5OWRnZ/OrX/0Kh8NBQUHBJRcvIjIYnqW9YWpg7e2aKQoj4h+GFEba29vZtWsXixYt6nkAs5lFixaxbdu2QT1Gc3MzHR0dJCYmDnhMW1sb9fX1fS4iIsPVogbWfi2clESY2QTAmbpWg6uRUDakV2ZVVRVdXV2kpqb2uT41NZWysrJBPcY3v/lN0tPT+wSac61du5a4uDjPJTMzcyhlioj00TMyojDSW0yElezRMQDsLtSqGjGOT1+ZP/7xj3n++ed56aWXiIiIGPC4VatWUVdX57kUFRX5sEoRCTbuTc+0Hfz55oxNABRGxFhDmkBNTk7GYrFQXl7e5/ry8nLS0tIueN8HH3yQH//4x7z66qvMmjXrgsfabDZsNttQShMRGZC2gx/YnLEJsBv2FtfR3ukgXKNHYoAh/dSFh4czd+7cPs2n7mbUhQsXDni/n/70p3z/+99n8+bNzJs3b/jViogMQ7OW9g5o0qgowPU90uiIGGXIETg/P58nnniC3/zmNxw6dIivfOUrNDU1sWzZMgDuvPNOVq1a5Tn+Jz/5Cffffz9PPfUU48ePp6ysjLKyMhobG733VYiIXEBPA6vCyLnMJpPn4zeOVhpYiYSyIa9zW7JkCZWVlaxevZqysjJycnLYvHmzp6m1sLAQs7kn4zz66KO0t7fziU98os/jrFmzhu9+97uXVr2IyCDoRHmDs/VIJd+4IdvoMiQEDWvR/cqVK1m5cmW/t23durXP56dOnRrOU4iIeE1Pz4hGRgZiMsGBM/WU1bWSFjfwAgORkaC3CSIS9DzbwWtkZEDZqa4lvgWHyy9ypIj3aTtCEQlq7Z0Oz/lXIsI0MjKQxal1HCpr4NieOsjs1chqT4J47fUkI0thRESCmntUBDRN0y97EljtfPjw/XzYBpwBHu91u9UOK7YrkMiIUhgRkaDW3OFaSRNmNmG1mC5ydAiKz4QV23E2V/GFp3dS0dDGmo9MI3dCIlQdhReXQ3O1woiMKE2gikhQczevRoZrVGRA8ZmY0i8nc/pCDjgn8HLFKEjPgeQpRlcmIUJhRESCmnuaxq4wclEfzE4B4J+HKnA6nQZXI6FEYUREglqzJ4xoVvpiFkxMItJqoay+lQOlOlu6+I7CiIgEtabu3Vc1MnJxEVYLH8hKBuCfhysMrkZCicKIiAQ1TdMMzaKprqmagkPab0R8R2FERIJaTwOrpmkG47rsFEwmeK+4jsrGNqPLkRChMCIiQc19krwojYwMSkpMBHPHJgDw9okag6uRUKEwIiJBrUlLe4fshhlpAPz7mM7iK76hMCIiQa1ZPSNDtni6K4xoRY34isKIiAS1Fs9qGvWMDFZmop0ZGbE4tNWI+IjCiIgENY2MDM+NM0YbXYKEEIUREQlqCiPD4+4bAWho6zSwEgkFCiMiEtSau6dptLR3aCaNimZsoh2AnafOGlyNBDuFEREJap6REatGRobqyklJAPz7uFbVyMhSGBGRoObegTXKpjAyVFdMdm0Nv/NULQ2tHQZXI8FMYUREglqTdmAdtgnJrmmaji4H/zig7eFl5CiMiEhQa9GJ8obNhMnz8ab3Sg2sRIKdwoiIBDXPuWnUM3JJ/n2siiqdq0ZGiMKIiAS1np4RTdMM15TUaLocTv6674zRpUiQUhgRkaDldDpp0jTNJbs6axQAf9qjqRoZGXqrICJBq63T4dnSXCfKG74PJp9lhrmWlsKTlB02kxYb0XOjPQniM40rToKCwohIIKstgubqgW8P8T8U7ika0D4jw2JPAquduL+t4C/h3dc9f84xVjus2B7SP2dy6RRGRAJVbRGsz4WO5oGPCfE/FO4pmvAwM2EWzUoPWXym6+enuZp/HCzj5wXHGJMQyYbPzXGttKk6Ci8udwXiEP0ZE+9QGBEJVM3VriBy2xOQPOX82/WHwjMyon6RSxCfCfGZXJHUyT2vw4GaLnZ3jGPuuESjK5MgojAiEuiSp0B6jtFV+CVtBe890bYwbpo5mj/uLuYPO4sVRsSrFEZEgsSRsgZe3lPC4TP12MIsfCylgpuMLspgnjCiZb1ecfu8MfxxdzF/fq+U1R+dht3ogiRo6BUq4s8u1KBadRSA8oZWfvL7Pbz0bglOZ8/NRQdPcpMNyupbSUv3Qa1+qFnLer0qb0IiYxPtFNY087d9ZfzHaKMrkmChMCLirwbRoNpujuCT//c+p7tcQ+bXT0vlmstGUdfSwVv/KoFO+OYf9/Lg1y9nVIzNV5X7De2+6l0mk4nb547hZ1uO8vudRfzHRyONLkmChMKIiL8aoEG1paOLP71Xyh93FVPcbKeURBZOTOK+G7OZnRnvOW5JRg08C1WN7Xz9d+/y7N15mM2mfp4oeKmB1fv+Y+4YHnr1KO+crOFMXSIaHBFvUBgRMdpAUzHd0zDuBtXm9k6efbuQR18/Tk2TGRjL1NGxrL0xm6uzkjGZ+gaNpCjXxhDTrGc4eHInr75Wx/VTU3sOCIE9SDzTNOoZ8Zr0+EiuyhrFG0cr+dv+Mr5gdEESFPQKFTHSxaZirHZOt0bymz8f5A+7imhodf1xHZ9k5z8/PIWPzkofeLSje8OqBzp+ATbgze5Lr8cO9j1ImrSaZkR8Lm8sbxyt5B8HyxVGxCsURkSMNMBUTJfTyfaTNWw80MzLjx/zXD8uyc6K6yZz2+UZF9/Eq3vDqo7GKr723G6KzrZwx4JxfGp+ZsjsQaJpmpHxoampZMRH0lDX6Qq6IpdIYUTEH3RPxVQ1trFxRxHPvVNISW0rYMZkgg9elsIdC8dxddaoofV9xGdijc/kpsWpfP1377J2j5WbFk8ndsS+EP/iaWAN1686b7KYTXwmbyx//cdBo0uRIKFXqMhIG+Ty3J+/tI8XdhbT3uUAIMFu5ZPzM/lc3jgyEy9tR4ebZ47mfwve51hFI8+9U8iXsy7p4QJGS4drWitKIyNet2R+Jv941RWM369oICtEl4+LdyiMiIykQS7Pvf3/jlLYlQTA7Mx47lwwjptnjSbCS70OFrOJL141kW/8cS/PvnOaL05OIBTO1NLU5h4ZURjxtuRoGx/ISoZT8MreM9ybY3RFEsgURkRG0gA9IU6cbDlYzq//fYrTzZGUksSVk5O450NTyJ0wMttsf3R2Oj945SBFNS3sKjQxf0Sexb94dmDVNM2IuHnmaDgFrx+tYmlTOwlR4Re9j0h/QuHNkYjx3OePSc+hOnYqd/2jky++2sW2lkyiU8fzu+ULePbuBSMWRMA1OvCJua5m1b/tOzNiz+NP3NM0amAdGVNHxwDQ3uXgue2FBlcjgUxhRMSH3jlRzQ0/f5N/Hq4g3GJm1Y3ZvPL1q1g4Kcknz//pXFcY2XX6rE+ez2iN3dM0UdpnZESY6GmmfvqtU7R1dhlYjQQyhRERH/n9ziI+9+Q7VDa0kZUSzaaVV/KlayZhvdgSXS/KSo0hOy2GTofz4gcHgaa27gZWm0ZGRlJyVDiVDW1s2lNqdCkSoBRGRHzgqbdO8o0X9tLR5eTmWaP508oPMHW0MQtsP5YTOsse3GEkWiMjI8r9M/XkmydxOkMj6Ip3KYyIjCBH9y/mP+4qAeDrH8rikU9dbujqjo/M7AkjDd1/rINVo2dkRGFkJF0/PY2ocAtHyht44/0qo8uRAKQwIjJCOrscPLzlfQDMJvjpJ2aR/+Ephp+sbmyS3bNvyZ7CWkNrGUlOp1MjIz4S03CCe6e3MN10ki2v/h1K9/RcaosMrk4CgV6hIiOgs8vBPc/v4dSRCv6fDf5r8WVcM89/tl2fNzYBDsHO0zVcZXQxI6S1w4G7NUYjIyOk+/xHvLic5cByG1ABPN7rmBA4B5JcOr1CRbzM4XDyzT/u45V9Z5htcY2CXJM1yuCq+po33h1GanE4nIaP1oyExl5TUDpR3gjpPv+Re4fhn/z9CG8creSqrGTuuyE7ZM6BJJdO0zQiXuR0OvneXw7yx93FWMwmvnlDttEl9Wt6ehwAtc3tHK9sNLiakeFZSRNuCcqw5TfiMz176Hzshhs54JzAY+/HcMI6uc9GfyIXojAicqlqizzz48++/Gd2bHuN6aaTPP6hMK6IG+CcNAazWnr+OL9zssbASkaOmld9b+roWD6UnYLTCRteP250ORJA9CoVuRTnnHvmc8Dn3KdUf7P7X6vdNbfup7afrOFzC8YZXYbXqXnVGCs+OJmCwxW8uLuE/5qVRIrRBUlA0KtU5FJ0n3vmnZwf8713XLtP3rlgHEvm95oftyf59Xz59pM1OJ1OTKbgmspoatfIiBHmjE3giklJvHW8mhd3FfNlowuSgKBXqYgX/HCHgwPOCSy/agKfvGkqBMgf9jCzibL6VopqWhibZDe6HK/q2Qpezau+tuK6ybx1vJq/Hyjny/orI4OgHxORi6kt8qwWONfJw7uZAHQ5nNw2J4NVN04NqBGGiaOieK8M3iuuDZ4w0v3/FV5RxnTTSaZRB6WRrtuqjhpbW4i4YlISOZnxtBef1F8ZGZRh/ZisX7+eBx54gLKyMmbPns0jjzxCbm5uv8ceOHCA1atXs2vXLk6fPs3DDz/Mvffeeyk1i/jOOT0h55oANDttTJ00nrX/MSvgVm1kpcRAGewtruWjs4Ngm/he/183ADfYgFLO3/fCj3t4goHJZGLldZN5+LfvAtDQ2kmMwTWJfxtyGNm4cSP5+fls2LCBvLw81q1bx+LFizly5AgpKee3KjU3NzNx4kRuv/12/vM//9MrRYv4THdPCLc90WeZYkVjG//9+/eoamonPT2D/73zRp+e8M5bJqdGA53sLa4zuhTv6PX/9eyJCJ57p5AbZ6Sx8rrJPcf4eQ9PsPhgdgovJ9mhEf6y7wyfnmh0ReLPhvzb86GHHmL58uUsW7aMadOmsWHDBux2O0899VS/x8+fP58HHniAT33qU9hstn6PEfF7yVM8eynUxE3j039u4fXGDNpGzeSnX7jZ0HPNXIopKa73q/tL6ugKpjP5Jk/hlHUyB5wTqE+Y7vm/Iz1HQcRHzGYTt3c3cm/aU9JnEzqRcw0pjLS3t7Nr1y4WLVrU8wBmM4sWLWLbtm1eK6qtrY36+vo+FxF/0NzeyRee3sHxyiZGx0Xwf1/IJSEq3Oiyhm1MQiSRVgtN7V2crAquzc88DazhalowygcmJwOuaZpn3j5tcDXiz4YURqqqqujq6iI1NbXP9ampqZSVlXmtqLVr1xIXF+e5ZGbqnYwYr73TwVef3c2eolri7VZ+e1cu6fGRRpd1ScLMJmZkxALwXlGQTNV08+zAqtU0hrH0auZ+4o0TNLdrdET655eT3KtWraKurs5zKSrSWR/FWJ0OJ/+5cQ9bj1QSYTXz5NL5TE4Jjpa8GRmureH3lwZnGNGmZ8YbHRdBdVM7z75daHQp4qeGFEaSk5OxWCyUl5f3ub68vJy0tDSvFWWz2YiNje1zETHSzwve55V9Z7BaTGz43FzmjkswuiSvmTra9fo6UtZgcCXepe3g/ccnu89Y/dgbJ2hp7zK4GvFHQwoj4eHhzJ07l4KCAs91DoeDgoICFi5c6PXiRIzmxNXU+c/DFVjMJn7xmTlce1kQbXBddZTLw04z3XQSZ+kenKXves6zQ21gj0i6d2DVyIjxrstOISM+kqrGNn63XaMjcr4hv0rz8/NZunQp8+bNIzc3l3Xr1tHU1MSyZcsAuPPOO8nIyGDt2rWAq+n14MGDno9LSkrYs2cP0dHRTJ48ecDnETGa0+nkyTdPcTeuDVUf+uRsFk/33gigoexJrv02XlxOFvCKDXBy/n4cK7YH7OqTJs8OrAojRrPWvM935sTyi9dO8trWEj47dh62sO73wlpqLQwjjCxZsoTKykpWr15NWVkZOTk5bN682dPUWlhYiNncM+BSWlrK5Zdf7vn8wQcf5MEHH+Saa65h69atl/4ViIyQh7ccpWBPCXfb4OsfnMz1ORlGl+Q98ZmuoNG9s+yXntlF8dkWvv/x6cwZm+DaqfTF5a7bA/QPRaMaWI3XK/TeCNxoAzqB3jtBBHjoFe8Y1luGlStXsnLlyn5vOzdgjB8/HqcziPYvkODUa8t3J06efaeIgu2FTDaVAHD9tCAZEektPtPzByAsw8GBmjPsbB/LnPRJBhfmHWpg9QPnhN5X9p3hl1uPkxwVzhNL5xF+9ljAh17xDr1KRc7Z8t0EfA74nHuPvhDYPvyytBhe2XeGw0HSxOpwOmlu1zSNX+gVehelzOT7u7ZyoL6V35ck8bmxfrmgUwygV6lI9xbizlsf58nD4by0xzUacvcHJnDr5RkhMaedneZapnz4THCEkZaOnhUbGhnxH7YwC1++ZiLf/fNBHt16nCWfTcBqdFHiFxRLRbo9djCMH7wbzgHnBJZ87CPcevPNIbN9uHt577GKRjq6HAZXc+la2l1fQ5jZ1NMoKX7hU7ljGRVjo6S2hYLD5Re/g4QEvWWQ0NCrJ+RcjsojmIE/7z2DyTSBH906k0/njvVtfQbLiI8kKty1LfypqiayjC7oErl3+oyJCMNkCqwzKQe7CKuFL109kR+8cojf7yziBqMLEr+gMCLB75yekHOZgWanjVpieOATs/nE3DG+rc8PmM0mLkuLYXdhLYfKGsgaZXRFl8a9kiY2UpMA/uizeePY8PpxyuvbQOdPFRRGJBT0Oq08yVM8V3c6nKx79SivHamkzhTLN5Z8iI8H0/LdIbosLZbdhbUcKauHAA8j7j1GYiMURvxRZLiFL149kU1/OwK4Xov6YxTaNJkqoSN5iuc08h2ps7j3DSe/OBzNEdNEVn3qwyEdRCC4mljdu6/GRupPnL/6bN444rpHrv6p3pGQpzAiIae908HK53bzl72uc82s/+wcbp412uiyDOcJI0GwvNe9x4hGRvxXlC3MMyX6zNuFOmdNiFMYkZDS2tHFl5/Zxd8PlBMeZuaxO+YGzxbvlyg7zbWipqS2xTOyEKjcPSMxERoZ8Wc3z3S9Cahuauepf580uBoxksKIhIzWTgfL/28n/zxcQYTVzJNL5/HB7FSjy/IbcXYrabERAJyu7r/ZN1B4pmk0MuLXei+7fnTrcaob2wysRoykMCIh47t/2s+b71dhD7fw68/nclWgLxkZAZd1T9WcqmoyuJJL09Ta3cCq1TQBYdKoKBrbOnnkn8eMLkUMojAiQa+x+13yvpJ6om1h/N8Xclk4Kbi3dx+u7NHdYSRoRkY0TRMIvnDlBACeefs075cHfs+SDJ3CiAS1uuYO7n95PwDRNgvP3J3HvPGJBlflv9xNrKerA3tkRPuMBJaczHgWTU2h0+Hk/k37dXLVEKQwIkGrpqmdz/zqbY6WNwLww1tnkpMZb2xRfu6yVFcT66kADyPufUZi1DMSMNZ8dDoRVjNvn6jhT++VGl2O+JjGMCU4nLPde31rB6tf2g9VTVweWQ4OmDwq2sACA8OklCgsZhONbV0BvTNmc5umaQJNZqKdlddN5sF/HOX7fznEddkpakAOIXqlSuDrZ7v3WOAX4PqD6gCsdtfZd+WCbGEWJo2KggqjK7k0je1qYA0oVUcB+GKWg307yiiubeHZl+r4yo25IXGiSlEYkWDQa7v3huiJ3L9pP+9XNBJvt/Lj22aSmdAdRPRLbVAuS4vlRICHkSb1jAQGe5LrjcKLywEIBx4D15uII9B1PBLLyh167YYAhREJGg0xE/nsK63sLR9FUlQGD39xAZmpMUaXFXCy02I4sdfoKi5NR5cD0DSN34vPhBXbzzuj9s/+cZTCo+/yc35JR2MVVoWRoKcGVgkaqzcdYG9xHYlR4Ty7PI8pCiLDclmQfN/MJogKVxjxe/GZnnNGuS+f/8THqbCNA+Dld0uMrE58RGFEAp57T4mj5Q3E2608c1eeZ2tzGTr3XiMAHY7AXWIZbQvDbDYZXYYMQ1K0jbs+MBGA594pDPil5nJxCiMS0BrbOlnzp4OA64/Ps3fnMS1dQeRSZMRHEhVuAaD4bOBufqZlvYHtQ1NdOyS3dzn4zsvaeyTYKYxIwGpq62TZr7dz6Ew9AD+8dQbT0+MMrirwmUwmxidHAXCyMnDfkSZEKYwEMhOuUS2rxcyb71exaY/2HglmCiMSkJrbO1n29A52nDpLlM31Ll77iHjPpO7v5fuVjQZXMnwJ9nCjSxAv+PR8V/Pq9/9ykLNN7QZXIyNFYUQCR20RlO6htXA333v8OZpO7SLXVsgD10QYXVnQcQe74xWBG0YSoxRGgsFtc8YwJTWa6qZ2fvTXQ0aXIyNEreYSGHptbBYB/Bh6dgh9HW1q5mWTU7vDSGUjXQ4nlgBsBNXISHCwWkz86NaZfGLDNv6wq5jb5ozRiS6DkMKIBIbujc1+mfhNXjkTS6TVzPc+PoNpo7ubVbWpmVeNSYgEoLXDwcmqRianBN5yX4WR4DFvfCKfzRvLs+8U8u2X9vHXe64iwmoxuizxIk3TSEDo6HJ10r9yJpaT1snc94VPMW3u1T17EyiIeJXF1DMSsq+kzsBKhk8NrMHlGzdkMyrGxomqJn659bjR5YiXKYyI3+tyOPnZliMA2MLM/Prz85k3PtHgqkLHvuJ6o0sYFo2MBJe4SCvf/eh0AB57/TgltS0GVyTepDAifs3pdPKdl/fz5vtVAHzrpqnkTdR8sS/tD9CRETWwBp+bZqaRNyGRtk4HP/nbYaPLES9SGBG/9vOC9/nd9kLc/ZPzxiUYW1AIOlBahyMAd2KNt2uaJtiYTCbu/8g0TCb403ul7Dp91uiSxEvUwCr+o7aozwmz3ni/ki0FR5huglXzzBDgJ28LRLYwM03tXZyoamJyiv/v4+LEibvbRSMjQaLqaJ9PZ5jgnqlNbDlUwaOb2nl85S3a9j8IKIyIf+i1dNftauBq9/LdvWj5rgEmJkex+4xrqiYQwkhTexfuKtUzEuDsSa7X/IvLz7vpXuBeGzRX23j17TFcf8V8n5cn3qUwIv6he+kutz1BZcQ48n//HjVN7cwfn8j9H5nqWt2h5bs+NzklGs64VtTccnmG0eVcVH1zB9FARJhZSz8DXXwmrNjeZ7S0t3+8/gbXH7mfZ197l6vmzSEyXP/fgUxhRPxKW/xk7vpTM3sbM7gsNYb/t3QhFp3wzDCTUqKBzoBZ3nu2pYN0IEFTNMEhPnPANyDXXOmAI1DV2M4Tb57g6x/K8nFx4k1qYBW/8ot/HmNvcR3xdiu/WjpPZ141mHtq5mBpfUA0sdY0us5doima4GcL6/nz9ejW45TXtxpYjVwqhRHxK/88UoHFbOKXn5lDZqLd6HJCXmaCnQirmca2Tk5W+/8ZfM82u8OIQmyomJoWQ0tHFz/dfMToUuQSKIyIX3jj/UrPx/ffPJUrJicbWI24hZlNTO3ecj8Q9htxhxGtpAkdd181EYA/7i5mX7H//4xK/xRGxHe6z7p77mXHttfYtOU1AG6akcbSK8YbV6OcZ2ZGHEBA/KJ3n2JePSOhIzsthlty0gH4/l8O4nT6/3SinE8NrOIb/SzddZsPzA+DNlMEX74pF5NJewb4kxnuMBIAIyM1zR2AekZCzTduyGbzgTK2n6rhT++V8vEc/1/5JX0pjIhv9Fq6S/IUOhxOnnn7NC/sKgbgqsnJ5N+yEEvCWIMLlXPNSHeFkQPdTaz+vMGURkZCUNVR0pNh9dxOnn3nNM++VEiu7XJGx0VoO4AAojAivpU8hf3OCXz75X28V2wFJnD3BybwjZum+vUfuVCWlRpNeJirifVUdRMTR/nv5mc9IyNqYA1652yK9hngM+5NEjd2/2u1u/YqUSDxewoj4lO/eO0YP9tfgtMJsRFh/OQ/ZnHjzNFGlyUXYLWYmTo6lveKatlXUue3YaSzy0F9SzuEq4E1JPSzKVpFYxtf/927NLR2cltmI3dVrHXdrjDi9xRGxLvOOb8MQHlDK2/8+9/cDvxtfxlO5wQ+Ojudb92Uzei4SGPqlCGZlRHHe0W17C2u89v5+PKGNtxbocRFamQkJJyzKVoK8J93TOaOJ7fzYtFx7rJBl9OJ9mb1fwoj4j0DNKmmArcDzU4bGekZrL55AXkTdY6ZgNB9krLr4irYbTrJ2WNVUOrqy/C3+fjimp6fO4uaoENW3sQkHloym8eePw7AA38/wr13zNLpAfycwoh4T3eTasHUH7DxlJ2S2hbPTTmZcdz6gdk8NnOmVssEgnPm4z8IfNAG1AKPdx/jZ/PxvX/eJLR9ZFY6cWez4TV48/0qXv/lW6z/zOV+O8UoCiPiBQ6Hk20nqnnzjcPcBzy0x8QBZxpR4RY+fnkGn8kd61keKgHinPl4J04+/9QOqpraWXvrTGZFlLuCih/NxxefVRiRHldlJcNrcHlEOe+WvUf+un3cOiedWy/PIDbC6ncje6FOYUSGpldPyNmWDl49WM7fD5Rxpq6VyaYSCIcpqdHcccVMPjo7nSibfsQCVq/5eBMQM9HJ63vP8HpDOrPG+F+4LD57/h42EsK6R/d+0PG/4F5ls6/7AnRZIun48jtEjBpnVIXSi/5SyODVFuFcn4upuyckAVcvyO3gebE7wiJ5+PMf0juOIDR/fCJ/2XuGHafPwlT/W61yqkphRHrpNbrnxMm2EzU8v72Q45VNTDaV8PPwX3Lbur/gHD2b2WPimTUmjtmZ8UwaFY1F2wz4nMKIDMrp6ibeeH03d3Q0c0/7VznmdK2omJIaw40z0rgqK5lIqwWzhj6D1vzxiQDsPFVDe1cK/hRHnE4nh8vq0U+e9NE9umcCrkiHhVc62Xn6LLu2vQZHoNPh5EBxHXt7neogKtzC9Iw4ZmXE8YGsZBZOSsIWpubXkaYwIn31moZp6ejizWNVFBwsZ39pPZNNJdwRDmXhY5mTcxWfys1kerr/DdfLyJg6OobUWBvl9W3sL61jjtEF9XKmrpX61k7CLHpHKwMzmUzMH5/I/PBJcASeXDqPne1j2VNYy96SOvaX1NHU3sX2kzVsP1nDr/51kmhbGB+amsKS+ZksnJikBvwRojAiPc6ZhokEru++uKdhOi2R/OYrNxGRrHnWUGMymbjushSe31HEjpNn/SqMHC6rB2BMQiQ0GlyMBIy02Ag+kp7OR2a5TrTX5XByvLKRvcV17DpdQ8GhCioa2ti0p5RNe0qZmBzFp3IzufXyMYyKsV3k0WUoFEaEoppmXj9ayYm9/2b1OdMwY+Ij+eDUFD6YncKoaBth9iTCNA0Tsj6Y7Qoj/z5WxZeMLqaX3adrAVxLNxVGZLC699FxswBTgCmj4ROTknDcMpM9xbW8sKuYTe+WcKKqiR/99TA/2XyE6y4bxSfmjuGaKSlEhmsa51IpjIQYp9NJUU0L7xadZeeps/zrWBUnq5oAmG6qARu0xE3m2pwruWnmaKaNjtWwpHhcc9kokqPDqW5q71mh4Ae2nXBNLc4cEwcnDS5G/N85++j0y2rHvGI7c8ZmMmdsAt+6aSp/2lPK73cWsaeollcPVfDqoQrCw8wsmJjEtVNGceXkZLJSonWerWEYVhhZv349DzzwAGVlZcyePZtHHnmE3NzcAY//wx/+wP3338+pU6fIysriJz/5CTfddNOwi5bBae3o4lhFI8cqGjla3sCRsgb2FNViayolwdQAgB2YbTGRnRbDB5MdcAQeu2MOpvRsY4sXv2QLs/C5BePYUnAEgIbWDmIMrqmhtYP3imoB17b1IhfVz3lt+qg6et4+OtG2MD6TN5bP5I3l/fIGXthdzF/eO0NJbQtvHK3kjaOVroe2W5k/PpG8CYnkTkhk2uhYwixmX31lAWvIYWTjxo3k5+ezYcMG8vLyWLduHYsXL+bIkSOkpKScd/xbb73Fpz/9adauXctHPvIRnnvuOW655RZ2797NjBkzvPJFhJouh5O6lg7ONrdT29xOTVMHFQ2tlNa2UHK2hdLaVkpqWzhT1+I5V4dbOlW8avtv7Ka2vjdUd1+sdkz2ZF99KRKAvnT1JPbvjIBW+Mnfj/DdL8419JftPw6U0+lwMnFUFKmxfjRcI/7tnPPa9OucaRy3LGDVwiTuu+E6jlU08tqRCt44WsWu02epbe5gy8FythwsB1yrc+aMS2B6ehxTR8cwdXQsE5KjsCqg9GFyOp3Oix/WIy8vj/nz5/OLX/wCAIfDQWZmJl/72te47777zjt+yZIlNDU18Ze//MVz3YIFC8jJyWHDhg2Des76+nri4uKoq6sjNjZ2KOX6jc4uB22dDto7Xf+2dXbR1umgsa2TxtZOz78N7n9bO2hs6/m8vrWD2mZXAKlr6eBC/2vpVHlGPmIiwhiXaCcz0c7YRDszbeVMeONeuO0JSJ5y/p21NFcG4cTefzPxxZu4ue2HzMm7lm/fPNWQc3+0dXbxqcff5t3CWu5dlMW905rh8Wvgi69Deo7P65EgMcB5tvqw2mHJb6HXm7cOh5PjFY3sqjTxzzM2tp+qoaG187y7mk2QGhvBmIRIMuIjSY62kRAVTlyklQR7OPF2K1G2MCKtFiKsZiKtFmxWC5FWC1aLKaCmzgf793tIIyPt7e3s2rWLVatWea4zm80sWrSIbdu29Xufbdu2kZ+f3+e6xYsX8/LLLw/4PG1tbbS19bxzr6tzrQGvr68fSrkXtWbTfo6WN+LE1UvhdIID17/uz504cTjB4foEh9OJk+5/nbju4+x1n+7HcjhdIxjtXa4Acu4IxcUkU0uyqbbPdVZcZ6V0jz9F2SzE2KxER4QRF2ElOdpGZkQLiw59hzBHq+ugVqC0+9KtPiwS4mdA9Jj+n9zL32cJPsk2J/VtTjLbT/DOGy18/M2/ERdhxR4eRqTNQpj5Iu/6LvIeaDAvly6Hk7K6Fprau5gZZuZjiZHUnyqENic0NOrnWIbPHAd3FEBLTf+3N1fDi1+EJ28776Z0ID0sko/e9jhdCxIprGnh/bIGTlU3cbqmmdM1TbS2O6AS6iqh7vxHp8oZTxXx/T61xWwizGIizGzCYur+132xmAgzm13HmF2hxWRy7aAMYML1uetjcH9i6v7wfz42naxU7068uv9uX3TcwzkEJSUlTsD51ltv9bn+v//7v525ubn93sdqtTqfe+65PtetX7/emZKSMuDzrFmzxonr95Euuuiiiy666BLgl6KiogvmC79cTbNq1ao+oykOh4OamhqSkrThTH19PZmZmRQVFQXslFWg0PfaN/R99g19n31D3+e+nE4nDQ0NpKenX/C4IYWR5ORkLBYL5eXlfa4vLy8nLS2t3/ukpaUN6XgAm82Gzda3ES0+Pn4opQa92NhY/aD7iL7XvqHvs2/o++wb+j73iIuLu+gxQ2rnDQ8PZ+7cuRQUFHiuczgcFBQUsHDhwn7vs3Dhwj7HA2zZsmXA40VERCS0DHmaJj8/n6VLlzJv3jxyc3NZt24dTU1NLFu2DIA777yTjIwM1q5dC8A999zDNddcw89+9jNuvvlmnn/+eXbu3Mnjjz/u3a9EREREAtKQw8iSJUuorKxk9erVlJWVkZOTw+bNm0lNTQWgsLAQc69O+iuuuILnnnuO73znO3zrW98iKyuLl19+WXuMDJPNZmPNmjXnTWOJ9+l77Rv6PvuGvs++oe/z8Ax5nxERERERb9IWcCIiImIohRERERExlMKIiIiIGEphRERERAylMBIk2trayMnJwWQysWfPHqPLCSqnTp3irrvuYsKECURGRjJp0iTWrFlDe3u70aUFvPXr1zN+/HgiIiLIy8tj+/btRpcUdNauXcv8+fOJiYkhJSWFW265hSNHjhhdVlD78Y9/jMlk4t577zW6lIChMBIkvvGNb1x0u10ZnsOHD+NwOHjsscc4cOAADz/8MBs2bOBb3/qW0aUFtI0bN5Kfn8+aNWvYvXs3s2fPZvHixVRUVBhdWlB5/fXXWbFiBW+//TZbtmyho6OD66+/nqamJqNLC0o7duzgscceY9asWUaXElgGc4I88W9//etfndnZ2c4DBw44Aee7775rdElB76c//alzwoQJRpcR0HJzc50rVqzwfN7V1eVMT093rl271sCqgl9FRYUTcL7++utGlxJ0GhoanFlZWc4tW7Y4r7nmGuc999xjdEkBQyMjAa68vJzly5fz29/+FrvdbnQ5IaOuro7ExESjywhY7e3t7Nq1i0WLFnmuM5vNLFq0iG3bthlYWfCrq3OdtF4/v963YsUKbr755j4/1zI4fnnWXhkcp9PJ5z//eb785S8zb948Tp06ZXRJIeHYsWM88sgjPPjgg0aXErCqqqro6ury7NzslpqayuHDhw2qKvg5HA7uvfderrzySu2C7WXPP/88u3fvZseOHUaXEpA0MuKH7rvvPkwm0wUvhw8f5pFHHqGhoYFVq1YZXXJAGuz3ubeSkhJuuOEGbr/9dpYvX25Q5SLDs2LFCvbv38/zzz9vdClBpaioiHvuuYdnn32WiIgIo8sJSNoO3g9VVlZSXV19wWMmTpzIJz/5Sf785z9jMpk813d1dWGxWPjsZz/Lb37zm5EuNaAN9vscHh4OQGlpKddeey0LFizg6aef7nMOJhma9vZ27HY7L7zwArfccovn+qVLl1JbW8umTZuMKy5IrVy5kk2bNvHGG28wYcIEo8sJKi+//DK33norFovFc11XVxcmkwmz2UxbW1uf2+R8CiMBrLCwkPr6es/npaWlLF68mBdeeIG8vDzGjBljYHXBpaSkhOuuu465c+fyzDPP6BeLF+Tl5ZGbm8sjjzwCuKYQxo4dy8qVK7nvvvsMri54OJ1Ovva1r/HSSy+xdetWsrKyjC4p6DQ0NHD69Ok+1y1btozs7Gy++c1vakpsENQzEsDGjh3b5/Po6GgAJk2apCDiRSUlJVx77bWMGzeOBx98kMrKSs9taWlpBlYW2PLz81m6dCnz5s0jNzeXdevW0dTUxLJly4wuLaisWLGC5557jk2bNhETE0NZWRkAcXFxREZGGlxdcIiJiTkvcERFRZGUlKQgMkgKIyIXsWXLFo4dO8axY8fOC3kaWBy+JUuWUFlZyerVqykrKyMnJ4fNmzef19Qql+bRRx8F4Nprr+1z/a9//Ws+//nP+74gkX5omkZEREQMpQ48ERERMZTCiIiIiBhKYUREREQMpTAiIiIihlIYEREREUMpjIiIiIihFEZERETEUAojIiIiYiiFERERETGUwoiIiIgYSmFEREREDKUwIiIiIob6/1py1hYddoOmAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.plot(np.linspace(-5, 5, 2000).astype(np.float32), np.exp(net_out))\n",
    "\n",
    "plt.hist(\n",
    "    sim_out[\"rts\"] * sim_out[\"choices\"],\n",
    "    bins=50,\n",
    "    histtype=\"step\",\n",
    "    fill=None,\n",
    "    density=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
